import importlib.util
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def load_shiftshare_module(root: Path):
    """Load the shift-share analysis script as a module named ``shiftshare``.

    The script's filename starts with a digit, so it cannot be brought in with
    a regular ``import`` statement; it is loaded from its file path instead.

    Args:
        root: Project root directory that contains the ``analysis`` folder.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be built for the path.
    """
    path = root / "analysis" / "19_iv_shiftshare_rollout_absorb.py"
    spec = importlib.util.spec_from_file_location("shiftshare", path)
    # Validate explicitly (not with ``assert``, which is stripped under -O)
    # and before ``module_from_spec`` so a missing spec gives a clear error.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


def empirical_p_ge(x: np.ndarray, x_obs: float) -> float:
    """Return the upper-tail empirical p-value of ``x_obs`` against draws ``x``.

    Non-finite draws (NaN, +/-inf) are dropped. The p-value is
    ``(#{x >= x_obs} + 1) / (n + 1)``, where ``n`` is the number of finite
    draws; the add-one terms keep the result strictly positive.

    Args:
        x: Draws from the reference distribution. Any array-like that can be
            converted to a float array is accepted (ndarray, list, Series).
        x_obs: Observed statistic to compare against the draws.

    Returns:
        The empirical p-value, or NaN when there are no finite draws or
        ``x_obs`` itself is not finite.
    """
    # Coerce first so plain sequences and integer arrays can be masked too.
    draws = np.asarray(x, dtype=float)
    draws = draws[np.isfinite(draws)]
    if draws.size == 0 or not np.isfinite(x_obs):
        return float("nan")
    n_at_least = int((draws >= x_obs).sum())
    return (n_at_least + 1) / (draws.size + 1)


def _draw_permutation_panel(
    ax,
    values: np.ndarray,
    f_obs: float,
    p_value: float,
    *,
    color: str,
    title: str,
    bins: int = 30,
) -> None:
    """Draw one histogram panel of permuted F statistics with the observed F.

    Args:
        ax: Matplotlib axes to draw on.
        values: Permutation F statistics; non-finite entries are not plotted.
        f_obs: Observed F statistic, drawn as a vertical red line.
        p_value: Empirical p-value shown in the annotation.
        color: Fill colour of the histogram bars.
        title: Panel title.
        bins: Number of histogram bins.
    """
    ax.hist(values[np.isfinite(values)], bins=bins, color=color, alpha=0.85, edgecolor="white")
    ax.axvline(f_obs, color="#d62728", lw=2)
    ax.set_title(title)
    ax.set_xlabel("First-stage F statistic")
    ax.set_ylabel("Count")
    # Annotation sits in the top-right corner, in axes-fraction coordinates.
    ax.text(
        0.98,
        0.95,
        f"Observed F = {f_obs:.2f}\nEmpirical p = {p_value:.4f}",
        transform=ax.transAxes,
        ha="right",
        va="top",
        fontsize=9,
    )
    ax.grid(axis="y", color="#e6e6e6", lw=0.8)


def main() -> None:
    """Plot permutation distributions of the first-stage F statistics.

    Reads the permutation F statistics from
    ``outputs/iv_shiftshare_randomization_firststage.csv``, refits the
    observed model through the shift-share analysis module, computes
    upper-tail empirical p-values, and writes a two-panel histogram to
    ``paper_joc/figures/figS1_randomization_firststage.png``.

    Raises:
        FileNotFoundError: If the permutation CSV has not been generated yet.
    """
    # The project root is one directory above the folder holding this script.
    root = Path(__file__).resolve().parents[1]
    fig_dir = root / "paper_joc" / "figures"
    fig_dir.mkdir(parents=True, exist_ok=True)

    perm_path = root / "outputs" / "iv_shiftshare_randomization_firststage.csv"
    if not perm_path.exists():
        raise FileNotFoundError(f"Missing permutation file: {perm_path}. Run analysis/25_iv_shiftshare_stress_tests.py first.")

    perm = pd.read_csv(perm_path)
    # Unparseable entries become NaN and are dropped before plotting/testing.
    f_net = pd.to_numeric(perm["F_netusoft"], errors="coerce").to_numpy()
    f_int = pd.to_numeric(perm["F_netusoft_x_edu"], errors="coerce").to_numpy()

    # Observed first-stage F statistics from the project's own fit routine.
    # NOTE(review): assumes ``first_stage.diagnostics`` is a DataFrame indexed
    # by endogenous variable with an "f.stat" column -- confirm against module.
    mod = load_shiftshare_module(root)
    df = mod.build_dataset(root)
    controls = ["agea", "gndr", "hinctnta"]
    obs = mod.fit_iv_interaction_absorb(df, "y_part_index", controls=controls)
    diag = obs.first_stage.diagnostics
    f_obs_net = float(diag.loc["netusoft", "f.stat"])
    f_obs_int = float(diag.loc["netusoft_x_edu", "f.stat"])

    p_net = empirical_p_ge(f_net, f_obs_net)
    p_int = empirical_p_ge(f_int, f_obs_int)

    fig, axes = plt.subplots(1, 2, figsize=(11.5, 4.2))
    bins = 30

    _draw_permutation_panel(
        axes[0],
        f_net,
        f_obs_net,
        p_net,
        color="#4c78a8",
        title="Panel A. Timing-permutation distribution of F(netusoft)",
        bins=bins,
    )
    _draw_permutation_panel(
        axes[1],
        f_int,
        f_obs_int,
        p_int,
        color="#f58518",
        title="Panel B. Timing-permutation distribution of F(netusoft×edu)",
        bins=bins,
    )

    fig.tight_layout()
    out_path = fig_dir / "figS1_randomization_firststage.png"
    fig.savefig(out_path, dpi=220)
    plt.close(fig)
    print(f"Wrote: {out_path}")


# Generate the figure only when run as a script, not when imported.
if __name__ == "__main__":
    main()
